
import abc
from contextlib import contextmanager
from collections import OrderedDict

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from softlearning.utils.keras import PicklableSequential

from softlearning.models.feedforward import feedforward_model, feedforward_model_with_sn
from softlearning.models.utils import flatten_input_structure, create_inputs
from softlearning.utils.tensorflow import nest

from serializable import Serializable

# Module-level shorthand for ``tf.keras.layers`` (not referenced in this
# chunk; presumably used elsewhere in the file -- verify before removing).
tfkl = tf.keras.layers

class BaseCritic(Serializable):
    """Base class for critics that score (optionally preprocessed) inputs.

    The constructor builds a Keras graph that applies the optional per-input
    preprocessors, casts every preprocessed input to ``tf.float32``,
    concatenates them along the last axis and feeds the result through the
    network returned by the subclass hook ``_score_fn``. Two Keras models
    are built over the same raw inputs:

    * ``critic_model`` -- maps the raw inputs to the critic score.
    * ``_gradient`` -- maps the raw inputs to the gradient of the score with
      respect to the concatenated float32 features.
    """

    def __init__(self,
                 input_shapes,
                 observation_keys,
                 output_shapes=1,
                 preprocessors=None,
                 ):
        """Build the critic graph.

        Args:
            input_shapes: Structure of input shapes, passed to
                ``create_inputs`` to build the Keras inputs.
            observation_keys: Stored as-is and exposed via
                ``observation_keys``.
            output_shapes: Passed to ``_score_fn`` as ``output_size``.
            preprocessors: Optional structure of preprocessors, flattened
                with ``flatten_input_structure``. ``None`` (or a ``None``
                entry) leaves the matching input unchanged.
        """
        # Subclasses may assign ``_name`` before delegating here (e.g.
        # ``FeedforwardCritic``); keep that value instead of clobbering it.
        self._name = getattr(self, '_name', None)
        self._input_shapes = input_shapes
        self._output_shapes = output_shapes
        self._observation_keys = observation_keys

        inputs_flat = create_inputs(input_shapes)
        preprocessors_flat = (
            flatten_input_structure(preprocessors)
            if preprocessors is not None
            else tuple(None for _ in inputs_flat))

        assert len(inputs_flat) == len(preprocessors_flat), (
            inputs_flat, preprocessors_flat)

        preprocessed_inputs = [
            preprocessor(input_) if preprocessor is not None else input_
            for preprocessor, input_
            in zip(preprocessors_flat, inputs_flat)
        ]

        def cast_and_concat(x):
            # Cast every leaf to float32 so inputs of different dtypes can
            # be concatenated into a single feature tensor.
            x = nest.map_structure(
                lambda element: tf.cast(element, tf.float32), x)
            x = nest.flatten(x)
            x = tf.concat(x, axis=-1)
            return x

        self.inputs = tf.keras.layers.Lambda(
            cast_and_concat
        )(preprocessed_inputs)

        critic_score = self._score_fn(
            output_size=output_shapes,
        )(self.inputs)
        # ``tf.gradients`` only works in graph mode and differentiates the
        # *sum* of ``critic_score``. The gradient is taken w.r.t. the
        # concatenated float32 features, not the raw (pre-preprocessing)
        # inputs.
        gradient = tf.gradients(critic_score, self.inputs)[0]
        self._gradient = tf.keras.Model(inputs_flat, gradient)
        self.critic_model = tf.keras.Model(inputs_flat, critic_score)

    @property
    def observation_keys(self):
        """The ``observation_keys`` passed to the constructor."""
        return self._observation_keys

    @property
    def input_names(self):
        """Input names of the underlying ``critic_model``."""
        return self.critic_model.input_names

    def reset(self):
        """Reset and clean the critic (no-op by default)."""
        pass

    def get_weights(self):
        """Return the weights of ``critic_model``."""
        return self.critic_model.get_weights()

    def set_weights(self, *args, **kwargs):
        """Forward all arguments to ``critic_model.set_weights``."""
        return self.critic_model.set_weights(*args, **kwargs)

    @property
    def trainable_variables(self):
        """Trainable variables of ``critic_model``."""
        return self.critic_model.trainable_variables

    def _score_fn(self, output_size):
        """Return the scoring network applied to the concatenated inputs.

        Subclasses must override this. ``__init__`` calls it as
        ``self._score_fn(output_size=output_shapes)(self.inputs)``, so any
        attributes it reads must be set before ``BaseCritic.__init__`` runs.
        """
        raise NotImplementedError

    # NOTE(review): unless ``Serializable`` supplies ``abc.ABCMeta`` (not
    # visible here), this decorator is not enforced at instantiation.
    # ``FeedforwardCritic`` does not override ``discriminate``, so enforcing
    # it would make that class non-instantiable.
    @abc.abstractmethod
    def discriminate(self, observations_p, observations_q):
        """Compute the score of p & q."""
        raise NotImplementedError

    def get_diagnostics(self, inputs):
        """Return diagnostic information of the critic.

        Arguments:
            inputs: Inputs accepted by ``critic_model.predict`` (presumably
                ordered like ``input_names`` -- confirm with callers).

        Returns:
            diagnostics: OrderedDict with the mean and standard deviation of
                the predicted critic scores.
        """
        score = self.critic_model.predict(inputs)
        diagnostics = OrderedDict((
            ('logit_mean', np.mean(score)),
            ('logit_std', np.std(score)),
        ))
        return diagnostics

    def __getstate__(self):
        # Add the current weights to Serializable's state so they can be
        # restored after unpickling.
        state = Serializable.__getstate__(self)
        state['pickled_weights'] = self.get_weights()
        return state

    def __setstate__(self, state):
        # Rebuild via Serializable first, then restore the pickled weights.
        Serializable.__setstate__(self, state)
        self.set_weights(state['pickled_weights'])
        
class FeedforwardCritic(BaseCritic):
    """Critic whose score network is a fully connected feedforward model.

    Args:
        hidden_layer_sizes: Hidden layer widths passed to the model factory.
        *args, **kwargs: Forwarded to ``BaseCritic.__init__``.
        sn: If truthy, build the network with ``feedforward_model_with_sn``
            (presumably spectral normalization -- confirm) instead of
            ``feedforward_model``.
        activation: Hidden-layer activation passed to the factory.
        output_activation: Output activation passed to the factory.
        name: Optional name stored on ``self._name``.
    """

    def __init__(self,
                 hidden_layer_sizes,
                 *args,
                 sn=False,
                 activation='relu',
                 output_activation='linear',
                 name=None,
                 **kwargs):
        # Must be set before ``BaseCritic.__init__``, which calls
        # ``self._score_fn`` while building the graph.
        self._hidden_layer_sizes = hidden_layer_sizes
        self._activation = activation
        self._output_activation = output_activation
        self._sn = sn

        # Hand the constructor's locals to Serializable (presumably recorded
        # so the object can be re-created on unpickling). No other locals are
        # bound before this call.
        self._Serializable__initialize(locals())
        super(FeedforwardCritic, self).__init__(*args, **kwargs)
        # Assigned after the base constructor, which itself assigns
        # ``self._name``, so the caller-provided name is not overwritten.
        self._name = name

    def _score_fn(self, output_size):
        """Build the feedforward scoring network.

        Args:
            output_size: Output size of the network.

        Returns:
            The model built by ``feedforward_model_with_sn`` when ``sn`` is
            truthy, otherwise by ``feedforward_model``.
        """
        model_factory = (
            feedforward_model_with_sn if self._sn else feedforward_model)
        return model_factory(
            hidden_layer_sizes=self._hidden_layer_sizes,
            output_size=output_size,
            activation=self._activation,
            output_activation=self._output_activation)

